{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1. 准备数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir -p data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "TRAIN_DATA_NUM = 1000\n",
    "TEST_DATA_NUM = 100\n",
    "\n",
    "TRAIN_DATA_PATH = \"data/train.jsonl\"\n",
    "TRAIN_DATA_WITH_STEPS_PATH = \"data/train_with_steps.jsonl\"\n",
    "TEST_DATA_PATH = \"data/test.jsonl\"\n",
    "\n",
    "RL_DATA_PATH = \"data/rl_data.jsonl\"\n",
    "RL_DATA_NUM = 1000"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1.1 合成 Countdown 数据集\n",
    "\n",
    "限制：\n",
    "\n",
    "1. 仅生成顺序 ops 表达式，无法生成 (A+B)/(C-D) 这种表达式\n",
    "2. 每个数字均被使用，且每个数字只使用一次。\n",
    "3. 中间计算结果必须是整数。\n",
    "4. 必须有解。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tqdm\n",
    "import random\n",
    "import json\n",
    "from typing import List, Tuple\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_solution_to_expression(solution):\n",
    "    \"\"\"将计算步骤转换为标准数学表达式\"\"\"\n",
    "    if not solution:\n",
    "        return \"\"\n",
    "    \n",
    "    # 第一步\n",
    "    expr = f\"({solution[0][0]} {solution[0][1]} {solution[0][2]})\"\n",
    "    \n",
    "    # 后续步骤\n",
    "    for step in solution[1:]:\n",
    "        expr = f\"({expr} {step[1]} {step[2]})\"\n",
    "    \n",
    "    # 去掉最外层的括号\n",
    "    if expr.startswith('(') and expr.endswith(')'):\n",
    "        expr = expr[1:-1]\n",
    "    \n",
    "    return expr\n",
    "\n",
    "def gen_dataset(\n",
    "    num_samples: int,\n",
    "    num_operands: int = 4,\n",
    "    max_target: int = 999,\n",
    "    min_number: int = 1,\n",
    "    max_number: int = 999,\n",
    "    operations: List[str] = ['+', '-', '*', '/'],\n",
    "    op_weights: dict = {'*': 0.2, '/': 0.7, '+': 0.05, '-': 0.05},\n",
    "    small_number_ratio: float = 0.8,  # 80%的数字选自小范围\n",
    "    small_range_ratio: float = 0.1,   # 小范围为总范围的前10%\n",
    "    seed_value: int = 42,\n",
    ") -> List[Tuple]:\n",
    "    random.seed(seed_value)\n",
    "    samples = []\n",
    "    \n",
    "    # 计算小范围的上限\n",
    "    small_range_upper = min_number + int((max_number - min_number) * small_range_ratio)\n",
    "    \n",
    "    for _ in tqdm.tqdm(range(num_samples)):\n",
    "        while True:\n",
    "            # 生成随机数，80%的数字从小范围中选择\n",
    "            numbers = []\n",
    "            for _ in range(num_operands):\n",
    "                if random.random() < small_number_ratio:\n",
    "                    # 从小范围选择\n",
    "                    num = random.randint(min_number, small_range_upper)\n",
    "                else:\n",
    "                    # 从全范围选择\n",
    "                    num = random.randint(min_number, max_number)\n",
    "                numbers.append(num)\n",
    "            \n",
    "            # 尝试生成有效表达式\n",
    "            solution = []\n",
    "            nums_left = numbers.copy()\n",
    "            valid = True\n",
    "            \n",
    "            # 取第一个数作为初始值\n",
    "            n1 = nums_left.pop(0)\n",
    "            \n",
    "            # 依次处理剩余数字\n",
    "            while nums_left and valid:\n",
    "                n2 = nums_left.pop(0)\n",
    "                \n",
    "                # 根据权重选择操作符\n",
    "                weighted_ops = []\n",
    "                for op in operations:\n",
    "                    weight = op_weights.get(op, 1.0 / len(operations))\n",
    "                    weighted_ops.extend([op] * int(weight * 100))\n",
    "                \n",
    "                random.shuffle(weighted_ops)\n",
    "                valid_op_found = False\n",
    "                \n",
    "                # 尝试所有操作符\n",
    "                tried_ops = set()\n",
    "                while weighted_ops and not valid_op_found:\n",
    "                    op = weighted_ops.pop()\n",
    "                    if op in tried_ops:\n",
    "                        continue\n",
    "                    tried_ops.add(op)\n",
    "                    \n",
    "                    if op == '+':\n",
    "                        result = n1 + n2\n",
    "                        valid_op_found = True\n",
    "                    elif op == '-':\n",
    "                        result = n1 - n2\n",
    "                        valid_op_found = True\n",
    "                    elif op == '*':\n",
    "                        result = n1 * n2\n",
    "                        valid_op_found = True\n",
    "                    elif op == '/' and n2 != 0 and n1 % n2 == 0:\n",
    "                        result = n1 // n2\n",
    "                        valid_op_found = True\n",
    "                    else:\n",
    "                        continue\n",
    "                        \n",
    "                    if valid_op_found:\n",
    "                        solution.append((n1, op, n2, result))\n",
    "                        n1 = result\n",
    "                        break\n",
    "                \n",
    "                # 如果所有操作符都尝试过仍未找到有效操作符\n",
    "                if not valid_op_found:\n",
    "                    valid = False\n",
    "            \n",
    "            # 如果生成了有效表达式\n",
    "            if valid:\n",
    "                target = n1  # 最终结果\n",
    "                if target > 0 and target <= max_target:  # 确保结果在有效范围内\n",
    "                    random.shuffle(numbers) # 打乱顺序，避免和答案顺序一致\n",
    "                    samples.append({\n",
    "                        \"numbers\": numbers,\n",
    "                        \"target\": target,\n",
    "                        \"ground_truth_solution\": convert_solution_to_expression(solution),\n",
    "                        \"solution_steps\": solution,\n",
    "                    })\n",
    "                    break\n",
    "    \n",
    "    # 缺点，这种方法只能生成顺序的表达式，无法生成比如 (A+B)/(C-D) 这种表达式\n",
    "    return samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████| 2100/2100 [00:03<00:00, 622.10it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[{'numbers': [29, 11, 27, 10],\n",
       "  'target': 524,\n",
       "  'ground_truth_solution': '((29 - 10) * 27) + 11',\n",
       "  'solution_steps': [(29, '-', 10, 19),\n",
       "   (19, '*', 27, 513),\n",
       "   (513, '+', 11, 524)]},\n",
       " {'numbers': [34, 35, 5, 56],\n",
       "  'target': 272,\n",
       "  'ground_truth_solution': '((56 * 5) / 35) * 34',\n",
       "  'solution_steps': [(56, '*', 5, 280), (280, '/', 35, 8), (8, '*', 34, 272)]},\n",
       " {'numbers': [34, 5, 8, 15],\n",
       "  'target': 816,\n",
       "  'ground_truth_solution': '((8 * 15) * 34) / 5',\n",
       "  'solution_steps': [(8, '*', 15, 120),\n",
       "   (120, '*', 34, 4080),\n",
       "   (4080, '/', 5, 816)]}]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "samples = gen_dataset(TRAIN_DATA_NUM + TEST_DATA_NUM + RL_DATA_NUM)\n",
    "samples[:3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(TRAIN_DATA_WITH_STEPS_PATH, \"w\") as f:\n",
    "    for sample in samples[:TRAIN_DATA_NUM]:\n",
    "        f.write(json.dumps(sample) + \"\\n\")\n",
    "\n",
    "with open(TRAIN_DATA_PATH, \"w\") as f:\n",
    "    for sample in samples[:TRAIN_DATA_NUM]:\n",
    "        del sample[\"solution_steps\"]\n",
    "        f.write(json.dumps(sample) + \"\\n\")\n",
    "\n",
    "with open(TEST_DATA_PATH, \"w\") as f:\n",
    "    for sample in samples[TRAIN_DATA_NUM:TRAIN_DATA_NUM + TEST_DATA_NUM]:\n",
    "        del sample[\"solution_steps\"]\n",
    "        f.write(json.dumps(sample) + \"\\n\")\n",
    "\n",
    "with open(RL_DATA_PATH, \"w\") as f:\n",
    "    for sample in samples[TRAIN_DATA_NUM + TEST_DATA_NUM:]:\n",
    "        del sample[\"solution_steps\"]\n",
    "        f.write(json.dumps(sample) + \"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 小作业\n",
    "\n",
    "目前只能生成顺序的表达式，无法生成比如 (A+B)/(C-D) 这种表达式，请修改为可以生成这种表达式。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. 生成简单一些的数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1100/1100 [00:00<00:00, 2091.09it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[{'numbers': [77, 7, 77],\n",
       "  'target': 7,\n",
       "  'ground_truth_solution': '(77 / 77) * 7',\n",
       "  'solution_steps': [(77, '/', 77, 1), (1, '*', 7, 7)]},\n",
       " {'numbers': [57, 63, 11],\n",
       "  'target': 131,\n",
       "  'ground_truth_solution': '(63 + 57) + 11',\n",
       "  'solution_steps': [(63, '+', 57, 120), (120, '+', 11, 131)]},\n",
       " {'numbers': [11, 15, 22],\n",
       "  'target': 17,\n",
       "  'ground_truth_solution': '(22 / 11) + 15',\n",
       "  'solution_steps': [(22, '/', 11, 2), (2, '+', 15, 17)]}]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "samples_simple = gen_dataset(RL_DATA_NUM + TEST_DATA_NUM, num_operands=3)\n",
    "samples_simple[:3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"data/rl_data_simple.jsonl\", \"w\") as f:\n",
    "    for sample in samples_simple[:RL_DATA_NUM]:\n",
    "        del sample[\"solution_steps\"]\n",
    "        f.write(json.dumps(sample) + \"\\n\")\n",
    "\n",
    "with open(\"data/test_simple.jsonl\", \"w\") as f:\n",
    "    for sample in samples_simple[RL_DATA_NUM:]:\n",
    "        del sample[\"solution_steps\"]\n",
    "        f.write(json.dumps(sample) + \"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 10000/10000 [00:05<00:00, 1859.08it/s]\n"
     ]
    }
   ],
   "source": [
    "samples_simple_10k = gen_dataset(10000, num_operands=3, seed_value=888)\n",
    "with open(\"data/rl_data_simple_10k.jsonl\", \"w\") as f:\n",
    "    for sample in samples_simple_10k[:10000]:\n",
    "        del sample[\"solution_steps\"]\n",
    "        f.write(json.dumps(sample) + \"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. （可选）使用现成的数据集\n",
    "\n",
    "从 Jiayi-Pan/Countdown-Tasks-3to4 或者类似的数据集中抽取，这些数据集确保结果可解，所以无需计算过程。\n",
    "\n",
    "此处略\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
